{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bb9102c3-5b8d-4295-8f29-113b35ec5679",
   "metadata": {},
   "source": [
    "# 一、LLM 预训练"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8557a6a6-294a-49c3-a8f6-e58bc3bf443d",
   "metadata": {},
   "source": [
    "1.1 初始化 LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "25f1fad8-772c-474e-a43e-77623106485d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen2Config {\n",
       "  \"_name_or_path\": \"autodl-tmp/qwen-1.5b\",\n",
       "  \"architectures\": [\n",
       "    \"Qwen2ForCausalLM\"\n",
       "  ],\n",
       "  \"attention_dropout\": 0.0,\n",
       "  \"bos_token_id\": 151643,\n",
       "  \"eos_token_id\": 151643,\n",
       "  \"hidden_act\": \"silu\",\n",
       "  \"hidden_size\": 1536,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": 8960,\n",
       "  \"max_position_embeddings\": 131072,\n",
       "  \"max_window_layers\": 28,\n",
       "  \"model_type\": \"qwen2\",\n",
       "  \"num_attention_heads\": 12,\n",
       "  \"num_hidden_layers\": 28,\n",
       "  \"num_key_value_heads\": 2,\n",
       "  \"rms_norm_eps\": 1e-06,\n",
       "  \"rope_theta\": 1000000.0,\n",
       "  \"sliding_window\": null,\n",
       "  \"tie_word_embeddings\": true,\n",
       "  \"torch_dtype\": \"bfloat16\",\n",
       "  \"transformers_version\": \"4.44.2\",\n",
       "  \"use_cache\": true,\n",
       "  \"use_mrope\": false,\n",
       "  \"use_sliding_window\": false,\n",
       "  \"vocab_size\": 151936\n",
       "}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载定义好的模型参数-此处以 Qwen-2.5-1.5B 为例\n",
    "# 使用 transformers 的 Config 类进行加载\n",
    "from transformers import AutoConfig\n",
    "\n",
    "model_path = \"autodl-tmp/qwen-1.5b\"\n",
    "config = AutoConfig.from_pretrained(model_path)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "82b075a1-4fe9-4abb-b5b4-769d1c1a7156",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training new model from scratch - Total size=1472.20M params\n"
     ]
    }
   ],
   "source": [
    "# 使用该配置生成一个定义好的模型\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_config(config,trust_remote_code=True)\n",
    "model.to(\"cuda\")\n",
    "n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())\n",
    "print(f\"Training new model from scratch - Total size={n_params/2**20:.2f}M params\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e05ea707-23db-4e67-8b7d-e57d019887dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen2ForCausalLM(\n",
       "  (model): Qwen2Model(\n",
       "    (embed_tokens): Embedding(151936, 1536)\n",
       "    (layers): ModuleList(\n",
       "      (0-27): 28 x Qwen2DecoderLayer(\n",
       "        (self_attn): Qwen2SdpaAttention(\n",
       "          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
       "          (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
       "          (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
       "          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
       "          (rotary_emb): Qwen2RotaryEmbedding()\n",
       "        )\n",
       "        (mlp): Qwen2MLP(\n",
       "          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
       "          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
       "          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
       "        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
       "      )\n",
       "    )\n",
       "    (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 看一下模型\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3408137b-eb50-4119-be1c-7a4ff951ab24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen2TokenizerFast(name_or_path='autodl-tmp/qwen-1.5b', vocab_size=151643, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={\n",
       "\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151646: AddedToken(\"<|object_ref_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151647: AddedToken(\"<|object_ref_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151648: AddedToken(\"<|box_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151649: AddedToken(\"<|box_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151650: AddedToken(\"<|quad_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151651: AddedToken(\"<|quad_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151652: AddedToken(\"<|vision_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151653: AddedToken(\"<|vision_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151654: AddedToken(\"<|vision_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151655: AddedToken(\"<|image_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151656: AddedToken(\"<|video_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151657: AddedToken(\"<tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151658: AddedToken(\"</tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151659: AddedToken(\"<|fim_prefix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151660: AddedToken(\"<|fim_middle|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151661: AddedToken(\"<|fim_suffix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151662: AddedToken(\"<|fim_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151663: AddedToken(\"<|repo_name|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151664: AddedToken(\"<|file_sep|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载一个预训练好的 tokenizer\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "221a0fe2-a244-4e73-b82c-6da255d710dd",
   "metadata": {},
   "source": [
    "1.2 预训练数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "936261a6-94cf-4cf3-842c-d3f1fde47a71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "66ae9baa159b424ea5f5bc8d05b9b567",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 加载预训练数据\n",
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset('json', data_files='autodl-tmp/dataset/pretrain_data/mobvoi_seq_monkey_general_open_corpus_small.jsonl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "068edbb9-cb3c-49b1-aaf9-67b97ddfc58c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': '在查处虚开增值税专用发票案件中，常常涉及进项留抵税额和税款损失的认定和处理。在计算税款损失时，要不要将进项留抵税额包括在内？\\n对此，实务中存在意见分歧。\\n有人主张归并，即计算税款损失时包括进项留抵税额；\\n有人主张剥离，即计算税款损失时剔除进项留抵税额。分析这个问题，需要确定进项留抵税额与税款损失之间是什么关系。\\n理清这二者之间的关系，首先需要了解增值税的概念和其抵扣机制。增值税是以商品（货物、服务等）在流转过程中产生的增值额作为计税依据而征收的一种流转税。为避免重复征税，在增值税中存在抵扣链条机制。\\n一般而言，交易上游企业缴纳的税额，交易下游企业可以对相应的税额进行抵扣。\\n对增值税一般纳税人来说，其购进货物、服务等取得增值税专用发票，发票上的税额是进项税额。\\n其出售货物、服务等，向购买方开具增值税专用发票，发票的税额是销项税额。\\n一般情况下，销项税额减去进项税额的金额是应纳税额，企业根据应纳税额按期申报纳税。\\n其次需要了解进项留抵税额的概念及产生原因。\\n在计算销项税额和进项税额的差额时，有时会出现负数，即当期进项税额大于当期销项税额。这个差额在当期未实现抵扣，为进项留抵税额，在以后纳税人有销项税额时再进行抵扣。\\n企业产生进项留抵税额的主要原因是其进项税额和销项税额时间上的不一致。\\n例如，企业前期集中采购货物和服务，投资大，销项税率低于进项税率等。\\n从税款抵扣的角度看，进项留抵税额只是购进的这部分进项税额参与到增值税应纳税额的计算过程中，但是其对应的进项税额抵扣还未真正实现，一般要等到其未来有相应的销项税额时，才能真正实现进项税额抵扣。\\n可见，进项留抵税额处于不确定状态，能否抵扣受到很多因素影响，例如企业经营中断，没有销项税额，这时进项留抵税额就无法实现抵扣。但如果企业按照税收政策规定申请进项留抵退税，进项税额抵扣就随之实现。\\n最后需要了解税款损失的概念。\\n税款损失，通常是指因虚开增值税专用发票，导致国家税款被骗或者流失的金额。关于税款损失，实务中有多种表述。\\n例如，北京大学法学院教授陈兴良曾谈到虚开行为本身不会造成国家税款损失，只有利用发票抵扣时才会造成国家税款损失。刘兵等编著的《虚开增值税专用发票案例司法观点和案例解析》一书中提到：“给国家税款造成损失的数额，实际上就是被骗取的国家税款在侦查终结以前无法追回的部分。”\\n赵清海与王家欣合著的《增值税专用发票虚开的判定与预防》一书中提到：“司法实践中，受票方用虚开的增值税专用发票予以抵扣的税款，从而导致受票方应纳税额的减少是法院所认定的国家税款流失的金额。”\\n从这些表述可见，税款损失应该是实际造成的损失，不应包括不确定的部分——进项留抵税额，进项留抵税额与税款损失之间不能直接画等号。\\n综上分析，进项留抵税额，只是使国家税款处于可能被抵扣的状态，还没有真正造成国家税款流失，一般情况下应将其从税款损失中剥离，特殊条件下将其归并入税款损失。\\n例如，当纳税人造假按照税收政策规定申请进项留抵税额退税后，有关税款损失将会从危险状态转化成危害结果，这时候要将有关进项留抵税额并入税款损失。\\n所以，在虚开增值税专用发票案件中，一般情况下，如果以纳税人的进项税额作为税款损失的计算基数，在对其进行行政处罚或刑事处罚时，应把进项留抵税额从税款损失中剔除，但纳税人申请进项留抵退税的除外。这样处理，把处罚与危害结果相对应，体现行政处罚法的过罚相当原则和刑法的罚当其罪原则。'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds[\"train\"][0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ef372a1f-e82f-4f5d-8495-f21f06b35635",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['text']\n"
     ]
    }
   ],
   "source": [
    "# 查看特征\n",
    "column_names = list(ds[\"train\"].features)\n",
    "print(column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1fa637f5-3b23-4a33-b19b-4c90d1815c39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "316489431b9e494eb8358a0d0048096f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset (num_proc=10):   0%|          | 0/100001 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 对数据集进行 tokenize\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize a batch of raw documents for causal-LM pretraining.\n",
    "\n",
    "    The tokenizer's EOS token is appended to every document before\n",
    "    tokenization. The tokenized documents are later concatenated and cut into\n",
    "    fixed-length blocks, and without an explicit end-of-document marker the\n",
    "    model would see unrelated documents as one continuous text.\n",
    "\n",
    "    Args:\n",
    "        examples: a batch passed by `map(..., batched=True)`;\n",
    "            `examples[\"text\"]` is the list of raw document strings.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output for the whole batch.\n",
    "    \"\"\"\n",
    "    # Fall back to no separator when the tokenizer defines no EOS token.\n",
    "    eos_token = tokenizer.eos_token or \"\"\n",
    "    return tokenizer([text + eos_token for text in examples[\"text\"]])\n",
    "\n",
    "# 批量处理\n",
    "tokenized_datasets = ds.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    num_proc=10,\n",
    "    remove_columns=column_names,\n",
    "    load_from_cache_file=True,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ec30197e-fe7f-4f0d-903c-663146421f58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'attention_mask'],\n",
       "        num_rows: 100001\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9ec5431b-e3cf-44e0-9260-479d984253e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预训练一般将文本拼接成固定长度的文本段\n",
    "from itertools import chain\n",
    "\n",
    "# 这里我们取块长为 2048\n",
    "block_size = 2048\n",
    "\n",
    "def group_texts(examples):\n",
    "    \"\"\"Pack a batch of tokenized documents into fixed-length training blocks.\n",
    "\n",
    "    Each column of the batch is flattened into one long sequence and cut into\n",
    "    consecutive chunks of `block_size` tokens. The trailing remainder shorter\n",
    "    than `block_size` is always dropped, so every returned chunk has exactly\n",
    "    `block_size` tokens and can be stacked into a batch without padding. A\n",
    "    batch holding fewer than `block_size` tokens in total yields no chunks.\n",
    "\n",
    "    Args:\n",
    "        examples: mapping of column name -> list of per-document token lists\n",
    "            (must contain `input_ids`).\n",
    "\n",
    "    Returns:\n",
    "        Mapping with the same columns, chunked, plus a `labels` column copied\n",
    "        from `input_ids`.\n",
    "    \"\"\"\n",
    "    # Flatten every column into a single token sequence.\n",
    "    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}\n",
    "    # Columns are assumed to have matching lengths, so the first one gives the total.\n",
    "    total_length = len(next(iter(concatenated_examples.values())))\n",
    "    # Always round down to a whole number of blocks: a short leftover chunk\n",
    "    # cannot be stacked with full-size chunks by a collator that does not pad.\n",
    "    total_length = (total_length // block_size) * block_size\n",
    "    # Split by chunks of block_size.\n",
    "    result = {\n",
    "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
    "        for k, t in concatenated_examples.items()\n",
    "    }\n",
    "    print(\"group texts input examples length%d after_group size%d\"%(len(examples['input_ids']),len(result[\"input_ids\"])))\n",
    "    # For causal LM the labels are the inputs; the model shifts them internally.\n",
    "    result[\"labels\"] = result[\"input_ids\"].copy()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "38428a53-6ba6-429f-8c4b-0985579e726b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae53ab8aaa0043418c2b7eb86f3d462b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Grouping texts in chunks of 2048 (num_proc=10):   0%|          | 0/100001 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "group texts input examples length10001 after_group size2752\n",
      "group texts input examples length10000 after_group size2817\n",
      "group texts input examples length10000 after_group size2820\n",
      "group texts input examples length10000 after_group size2817\n",
      "group texts input examples length10000 after_group size2787\n",
      "group texts input examples length10000 after_group size2797\n",
      "group texts input examples length10000 after_group size2800\n",
      "group texts input examples length10000 after_group size2835\n",
      "group texts input examples length10000 after_group size2778\n",
      "group texts input examples length10000 after_group size2825\n"
     ]
    }
   ],
   "source": [
    "# 批量处理\n",
    "lm_datasets = tokenized_datasets.map(\n",
    "    group_texts,\n",
    "    batched=True,\n",
    "    num_proc=10,\n",
    "    load_from_cache_file=True,\n",
    "    desc=f\"Grouping texts in chunks of {block_size}\",\n",
    "    batch_size = 40000,\n",
    ")\n",
    "train_dataset = lm_datasets[\"train\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b0dd8c8-fb1f-4af9-8af4-21285ba389c0",
   "metadata": {},
   "source": [
    "1.3 使用 Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3e1a85e-fc28-4154-870e-f6a09f108059",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "# 配置训练参数\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"autodl-tmp/output/pretrain\",\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    logging_steps=10,\n",
    "    num_train_epochs=1,\n",
    "    save_steps=100, \n",
    "    learning_rate=1e-4,\n",
    "    save_on_each_node=True,\n",
    "    gradient_checkpointing=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "62a97e46-ff06-4278-b318-e3e4da1b93d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
      "################################################################################\n",
      "WARNING!\n",
      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
      "to learn more and leave feedback.\n",
      "################################################################################\n",
      "\n",
      "  deprecation_warning()\n"
     ]
    }
   ],
   "source": [
    "from transformers import Trainer, default_data_collator\n",
    "from torchdata.datapipes.iter import IterableWrapper\n",
    "\n",
    "# 训练器\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset= IterableWrapper(train_dataset),\n",
    "    eval_dataset= None,\n",
    "    tokenizer=tokenizer,\n",
    "    # 数据已在 group_texts 中切成等长块并带有 labels，无需 padding 或 MLM 掩码，直接用 default_data_collator 组装 batch\n",
    "    data_collator=default_data_collator\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a929b11a-99f5-45fc-9f9a-05c0163204c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start train\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
      "  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='101' max='1751' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 101/1751 29:31 < 8:12:11, 0.06 it/s, Epoch 0.06/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>10.987700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>9.160700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>8.352700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>8.159800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>8.042500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>8.014400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>7.986700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>7.951800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>7.875500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "RuntimeError",
     "evalue": "[enforce fail at inline_container.cc:603] . unexpected pos 6546708864 vs 6546708760",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:652\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m    651\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(f) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[0;32m--> 652\u001b[0m     \u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_protocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_disable_byteorder_record\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    653\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:886\u001b[0m, in \u001b[0;36m_save\u001b[0;34m(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m    885\u001b[0m num_bytes \u001b[38;5;241m=\u001b[39m storage\u001b[38;5;241m.\u001b[39mnbytes()\n\u001b[0;32m--> 886\u001b[0m \u001b[43mzip_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_record\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_bytes\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:778] . PytorchStreamWriter failed writing file data/401: file write failed",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstart train\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m train_result \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1938\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1936\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1937\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1938\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1939\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1940\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1941\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1942\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1943\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2356\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2353\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m   2354\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2356\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2357\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2358\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2807\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2804\u001b[0m     metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluate(trial, ignore_keys_for_eval)\n\u001b[1;32m   2806\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 2807\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetrics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2808\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_save(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2890\u001b[0m, in \u001b[0;36mTrainer._save_checkpoint\u001b[0;34m(self, model, trial, metrics)\u001b[0m\n\u001b[1;32m   2886\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_model(output_dir, _internal_call\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m   2888\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39msave_only_model:\n\u001b[1;32m   2889\u001b[0m     \u001b[38;5;66;03m# Save optimizer and scheduler\u001b[39;00m\n\u001b[0;32m-> 2890\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_optimizer_and_scheduler\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2891\u001b[0m     \u001b[38;5;66;03m# Save RNG state\u001b[39;00m\n\u001b[1;32m   2892\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_save_rng_state(output_dir)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3006\u001b[0m, in \u001b[0;36mTrainer._save_optimizer_and_scheduler\u001b[0;34m(self, output_dir)\u001b[0m\n\u001b[1;32m   3001\u001b[0m     save_fsdp_optimizer(\n\u001b[1;32m   3002\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfsdp_plugin, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, output_dir\n\u001b[1;32m   3003\u001b[0m     )\n\u001b[1;32m   3004\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[1;32m   3005\u001b[0m     \u001b[38;5;66;03m# deepspeed.save_checkpoint above saves model/optim/sched\u001b[39;00m\n\u001b[0;32m-> 3006\u001b[0m     \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mOPTIMIZER_NAME\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3008\u001b[0m \u001b[38;5;66;03m# Save SCHEDULER & SCALER\u001b[39;00m\n\u001b[1;32m   3009\u001b[0m is_deepspeed_custom_scheduler \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_deepspeed_enabled \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
\u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m   3010\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr_scheduler, DeepSpeedSchedulerWrapper\n\u001b[1;32m   3011\u001b[0m )\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:651\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m    648\u001b[0m _check_save_filelike(f)\n\u001b[1;32m    650\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _use_new_zipfile_serialization:\n\u001b[0;32m--> 651\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(f) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[1;32m    652\u001b[0m         _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)\n\u001b[1;32m    653\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:499\u001b[0m, in \u001b[0;36m_open_zipfile_writer_file.__exit__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m    498\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 499\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfile_like\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_end_of_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    500\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    501\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream\u001b[38;5;241m.\u001b[39mclose()\n",
      "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:603] . unexpected pos 6546708864 vs 6546708760"
     ]
    }
   ],
   "source": [
    "# Launch training.\n",
    "# NOTE(review): the traceback saved with this cell shows the run aborting inside\n",
    "# torch.save while Trainer was writing the optimizer checkpoint (RuntimeError:\n",
    "# unexpected pos ...) - presumably the disk filled up; confirm free space.\n",
    "print('start train')\n",
    "train_result = trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ed2cd9-7169-4376-a26c-053918074761",
   "metadata": {},
   "source": [
    "# 二、模型 SFT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bb6e02b-c04c-45a4-b36c-904f9fedf61e",
   "metadata": {},
   "source": [
    "2.1 处理指令数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0d7cd012-fa2d-4c21-b6a5-c3830d12f59b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'conversations': [{'from': 'human',\n",
       "   'value': '针对健身房的新手，设计一套适合他们的健身器械使用指南，包括安全应用、正确姿势等方面。'},\n",
       "  {'from': 'assistant',\n",
       "   'value': '健身器械使用指南\\n1. 开始前，请先进行热身运动。这会帮助你的身体适应运动，并减少受伤的风险。\\n2. 在使用健身器械前，确保你已经了解了其使用方法。请阅读说明书或咨询教练以获得正确的使用技巧。\\n3. 谨防过度使用或过度挑战你的身体。 如果你觉得有些动作太难或太重，请添加锻炼计划，以逐步提高动作难度。\\n4. 使用合适的装备。 确保你拥有合适的运动鞋和舒适的运动服。 不要在裸露的脚或短裤上进行重量训练。\\n5. 在健身器械上使用安全装置。 这些通常用于保护你的身体免受不当操作造成的损伤。 例如，重量训练中，你需要使用杠铃和负重时，一定要使用卡子来防止重量滑落。\\n6. 注意正确的姿势。 如果你的姿势是错误的，那么你的身体很容易被伤害到，你也可能无法获得最佳的锻炼效果。 至关重要的是，保持直立的身体，保持头部和颈部的稳定，并使用合适的重量。\\n7. 保持合理的呼吸方式。 无论何时进行训练，都必须保持正常呼吸。 当你需要用力时，呼气； 当你放松时，吸气。\\n8. 安全存放器械。 在使用健身器械后，你需要把它们归还给适当的位置，以便其他人可以使用它们。\\n总之，健身器械的正确使用是关键之一，如果不健康和不安全，它们将无法帮助您达到您所需的健康成果。 选择适当的训练计划，并为训练提供足够的时间，以备逐渐适应新方法。 对于任何问题，请向教练咨询。'}],\n",
       " 'id': '66182880'}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# JSON Lines file: one conversation record per line. Read it as UTF-8 explicitly\n",
    "# (the text is Chinese and the platform default encoding may differ), stream the\n",
    "# lines instead of materialising them with readlines(), and skip blank lines so a\n",
    "# trailing newline cannot crash json.loads.\n",
    "with open(\"autodl-tmp/dataset/sft_data/BelleGroup/train_3.5M_CN.json\", encoding=\"utf-8\") as f:\n",
    "    lst = [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "# Show the first record to document the schema\n",
    "lst[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fc8c599-89e9-4c35-a011-d2e52a1a4d9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen2TokenizerFast(name_or_path='autodl-tmp/qwen-1.5b', vocab_size=151643, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={\n",
       "\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151646: AddedToken(\"<|object_ref_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151647: AddedToken(\"<|object_ref_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151648: AddedToken(\"<|box_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151649: AddedToken(\"<|box_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151650: AddedToken(\"<|quad_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151651: AddedToken(\"<|quad_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151652: AddedToken(\"<|vision_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151653: AddedToken(\"<|vision_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151654: AddedToken(\"<|vision_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151655: AddedToken(\"<|image_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151656: AddedToken(\"<|video_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151657: AddedToken(\"<tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151658: AddedToken(\"</tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151659: AddedToken(\"<|fim_prefix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151660: AddedToken(\"<|fim_middle|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151661: AddedToken(\"<|fim_suffix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151662: AddedToken(\"<|fim_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151663: AddedToken(\"<|repo_name|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151664: AddedToken(\"<|file_sep|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load a pretrained tokenizer from the local model directory\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_path = \"autodl-tmp/qwen-1.5b\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "46730b29-41c0-4295-81f2-913d069b4669",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Instruction-text processing\n",
    "# Reference: https://github.com/QwenLM/Qwen/blob/main/finetune.py\n",
    "def preprocess(sources, tokenizer, max_len, system_message: str = \"You are a helpful assistant.\"):\n",
    "    \"\"\"Tokenize multi-turn conversations into fixed-length SFT tensors.\n",
    "\n",
    "    Every conversation is rendered as a system turn followed by its own turns, each\n",
    "    framed as <|im_start|>{role} newline {text}<|im_end|> newline, where role is\n",
    "    system, human or assistant.\n",
    "\n",
    "    The masking arithmetic assumes that <|im_start|>, <|im_end|> and the newline\n",
    "    each tokenize to exactly one id; the length assertions fail otherwise.\n",
    "\n",
    "    Args:\n",
    "        sources: list of conversations; each conversation is a list of dicts with a\n",
    "            \"from\" key (\"human\" or \"assistant\") and a \"value\" key (the text).\n",
    "        tokenizer: callable such that tokenizer(text).input_ids is a list of token\n",
    "            ids; must also expose pad_token_id.\n",
    "        max_len: every sequence is right-padded or truncated to this length.\n",
    "        system_message: text of the system turn prepended to every conversation.\n",
    "\n",
    "    Returns:\n",
    "        dict with input_ids, labels and attention_mask tensors, one row per\n",
    "        conversation and max_len columns. labels keep the assistant responses and\n",
    "        the <|im_start|> / <|im_end|> / newline frame tokens; every other position\n",
    "        (system text, human text, role names, padding) is -100.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a turn's \"from\" value is neither \"human\" nor \"assistant\".\n",
    "    \"\"\"\n",
    "    # Label value skipped by the loss. It must be -100, the default ignore_index of\n",
    "    # PyTorch cross-entropy that Hugging Face causal-LM heads rely on. Using the pad\n",
    "    # token id instead would train the model to emit <|endoftext|> at every masked\n",
    "    # or padded position.\n",
    "    IGNORE_TOKEN_ID = -100\n",
    "\n",
    "    # Prompt template: the tag that opens each turn\n",
    "    roles = {\"human\": \"<|im_start|>human\", \"assistant\": \"<|im_start|>assistant\"}\n",
    "\n",
    "    # Special tokens are tokenizer-specific, so their ids are looked up here\n",
    "    im_start = tokenizer(\"<|im_start|>\").input_ids  # opens a turn (BOS)\n",
    "    im_end = tokenizer(\"<|im_end|>\").input_ids  # closes a turn (EOS)\n",
    "    nl_tokens = tokenizer('\\n').input_ids  # newline\n",
    "    # Tokenize each role tag once instead of on every turn\n",
    "    role_ids = {name: tokenizer(tag).input_ids for name, tag in roles.items()}\n",
    "\n",
    "    # The system turn is the same for every conversation, so build it once\n",
    "    system = (\n",
    "        im_start + tokenizer('system').input_ids + nl_tokens\n",
    "        + tokenizer(system_message).input_ids + im_end + nl_tokens\n",
    "    )\n",
    "    # The system text is not a training target; only its frame tokens are kept\n",
    "    system_target = im_start + [IGNORE_TOKEN_ID] * (len(system) - 3) + im_end + nl_tokens\n",
    "    assert len(system) == len(system_target)\n",
    "\n",
    "    # Concatenate the turns of every conversation\n",
    "    input_ids, targets = [], []\n",
    "    for source in tqdm(sources):\n",
    "        # A conversation must start with a human turn: drop a leading assistant turn\n",
    "        if source and source[0][\"from\"] != \"human\":\n",
    "            source = source[1:]\n",
    "        # Copy so the shared system lists are never mutated by += below\n",
    "        input_id, target = list(system), list(system_target)\n",
    "        for sentence in source:\n",
    "            speaker = sentence[\"from\"]\n",
    "            tag_ids = role_ids[speaker]  # KeyError on an unsupported role\n",
    "            _input_id = (\n",
    "                tag_ids + nl_tokens\n",
    "                + tokenizer(sentence[\"value\"]).input_ids + im_end + nl_tokens\n",
    "            )\n",
    "            input_id += _input_id\n",
    "            if speaker == \"human\":\n",
    "                # The human turn is context only: mask everything but the frame\n",
    "                _target = im_start + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + im_end + nl_tokens\n",
    "            else:\n",
    "                # The assistant turn is the target: mask the role name and its\n",
    "                # newline, keep the response tokens and the closing frame\n",
    "                _target = (\n",
    "                    im_start + [IGNORE_TOKEN_ID] * len(tag_ids)\n",
    "                    + _input_id[len(tag_ids) + 1:-2] + im_end + nl_tokens\n",
    "                )\n",
    "            target += _target\n",
    "        assert len(input_id) == len(target)\n",
    "        # Right-pad to max_len (a negative count pads nothing), then truncate\n",
    "        pad_len = max_len - len(input_id)\n",
    "        input_id += [tokenizer.pad_token_id] * pad_len\n",
    "        target += [IGNORE_TOKEN_ID] * pad_len\n",
    "        input_ids.append(input_id[:max_len])\n",
    "        targets.append(target[:max_len])\n",
    "    input_ids = torch.tensor(input_ids)\n",
    "    targets = torch.tensor(targets)\n",
    "\n",
    "    return dict(\n",
    "        input_ids=input_ids,\n",
    "        labels=targets,\n",
    "        # Padding positions are the ones holding the pad token id\n",
    "        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7b3576cb-04d7-448a-9bd1-07cb7b344e6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[151644,   8948,    198,  ..., 151643, 151643, 151643],\n",
       "         [151644,   8948,    198,  ..., 151643, 151643, 151643]]),\n",
       " 'labels': tensor([[151644, 151643, 151643,  ..., 151643, 151643, 151643],\n",
       "         [151644, 151643, 151643,  ..., 151643, 151643, 151643]]),\n",
       " 'attention_mask': tensor([[ True,  True,  True,  ..., False, False, False],\n",
       "         [ True,  True,  True,  ..., False, False, False]])}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick sanity check: run preprocess on the first two conversations\n",
    "preprocess([lst[idx][\"conversations\"] for idx in range(2)], tokenizer, 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "63e01dcf-4de4-4470-97dd-3317ef1aa00b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 自定义一个 Dataset\n",
    "from torch.utils.data import Dataset\n",
    "from typing import Dict\n",
    "\n",
    "class SupervisedDataset(Dataset):\n",
    "    \"\"\"Map-style dataset holding pre-tokenized SFT examples as tensors.\n",
    "\n",
    "    All kept examples are tokenized eagerly in __init__ through preprocess(), so\n",
    "    construction cost grows with the number of examples kept.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, raw_data, tokenizer, max_len: int, max_samples=10000):\n",
    "        \"\"\"Tokenize the leading records of raw_data.\n",
    "\n",
    "        Args:\n",
    "            raw_data: sliceable sequence of records, each with a \"conversations\" entry.\n",
    "            tokenizer: tokenizer forwarded to preprocess().\n",
    "            max_len: forwarded to preprocess() as the per-example sequence length.\n",
    "            max_samples: number of leading records to keep (non-negative int);\n",
    "                None keeps them all. Defaults to 10000, the cap that was\n",
    "                previously hard-coded.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Load and preprocess the data\n",
    "        sources = [example[\"conversations\"] for example in raw_data[:max_samples]]\n",
    "        data_dict = preprocess(sources, tokenizer, max_len)\n",
    "\n",
    "        self.input_ids = data_dict[\"input_ids\"]\n",
    "        self.labels = data_dict[\"labels\"]\n",
    "        self.attention_mask = data_dict[\"attention_mask\"]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of tokenized examples.\"\"\"\n",
    "        return len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Return the input_ids, labels and attention_mask rows of example i.\"\"\"\n",
    "        return dict(\n",
    "            input_ids=self.input_ids[i],\n",
    "            labels=self.labels[i],\n",
    "            attention_mask=self.attention_mask[i],\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "934316d3-098f-4889-9cb0-d234a630b194",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [00:08<00:00, 1235.98it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build the SFT training set; every example is padded or truncated to 2048 tokens\n",
    "train_ds = SupervisedDataset(lst, tokenizer=tokenizer, max_len=2048)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
